# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch
import math
import random

from sklearn.linear_model import LinearRegression
from itertools import chain, combinations
from scipy.stats import f as fdist
from scipy.stats import ttest_ind

from torch.autograd import grad

import scipy.optimize

import matplotlib
import matplotlib.pyplot as plt

from torch import nn, optim, autograd
import pdb
import torch
#from torchvision import datasets



def pretty(vector):
    """Render a tensor as a bracketed list of signed, 4-decimal numbers.

    The tensor is flattened first, so any shape is accepted; the result
    looks like ``"[+1.0000, -2.5000]"``.
    """
    flat_values = vector.view(-1).tolist()
    pieces = [format(value, "+.4f") for value in flat_values]
    return "[{}]".format(", ".join(pieces))


class InvariantRiskMinimization(object):
    """Linear regressor trained with an IRM-style gradient penalty.

    The model is ``y_hat = x @ phi @ w`` where ``phi`` is a learned square
    matrix (initialised to the identity) and ``w`` is a fixed vector of ones
    that is only used to measure the gradient penalty.

    The last entry of ``environments`` is held out as a validation set; the
    remaining entries are used for training. One model is trained per
    regularisation weight in ``[0, 1e-4, 1e-1]`` and the ``phi`` with the
    lowest validation mean squared error is kept.

    Args:
        environments: sequence of ``(x, y)`` tensor pairs; ``x`` is indexed
            as 2-D (``x.size(1)`` is the feature dimension).
        args: mapping with keys ``"seed"``, ``"verbose"``, ``"lr"``,
            ``"n_iterations"`` and (when verbose) ``"print_rate"``.
    """

    def __init__(self, environments, args):
        import warnings

        torch.manual_seed(args["seed"])
        # Start from +inf (not a finite magic number) so that a candidate is
        # always selected whenever at least one validation error is finite.
        best_err = float("inf")
        best_phi = None

        x_val = environments[-1][0]
        y_val = environments[-1][1]

        for reg in [0, 1e-4, 1e-1]:
            self.train(environments[:-1], args, reg=reg)
            err = (x_val @ self.solution() - y_val).pow(2).mean().item()

            if args["verbose"]:
                print("IRM (reg={:.3f}) has {:.3f} validation error.".format(
                    reg, err))

            if err < best_err:
                best_err = err
                best_phi = self.phi.clone()

        if best_phi is None:
            # Every validation error was NaN/inf (comparisons with NaN are
            # always False). Keep the last trained phi instead of crashing
            # with an unbound local, but make the situation visible.
            warnings.warn(
                "IRM: no finite validation error; keeping the last trained phi.")
            best_phi = self.phi.clone()
        self.phi = best_phi

    def train(self, environments, args, reg=0):
        """Fit ``phi`` on ``environments`` for a given trade-off ``reg``.

        The minimised objective is ``reg * error + (1 - reg) * penalty``
        where ``error`` is the sum of per-environment MSE losses and
        ``penalty`` is the sum over environments of the mean squared
        gradient of that loss with respect to ``w``.
        """
        dim_x = environments[0][0].size(1)

        self.phi = torch.nn.Parameter(torch.eye(dim_x, dim_x))
        # w is not given to the optimizer; it needs requires_grad only so
        # that the penalty can differentiate the loss with respect to it.
        self.w = torch.ones(dim_x, 1)
        self.w.requires_grad = True

        opt = torch.optim.Adam([self.phi], lr=args["lr"])
        loss = torch.nn.MSELoss()

        for iteration in range(args["n_iterations"]):
            penalty = 0
            error = 0
            for x_e, y_e in environments:
                error_e = loss(x_e @ self.phi @ self.w, y_e)
                # create_graph=True keeps the penalty differentiable w.r.t. phi.
                penalty += grad(error_e, self.w,
                                create_graph=True)[0].pow(2).mean()
                error += error_e

            opt.zero_grad()
            (reg * error + (1 - reg) * penalty).backward()
            opt.step()

            if args["verbose"] and iteration % args["print_rate"] == 0:
                w_str = pretty(self.solution())
                print("{:05d} | {:.5f} | {:.5f} | {:.5f} | {}".format(iteration,
                                                                      reg,
                                                                      error,
                                                                      penalty,
                                                                      w_str))

    def solution(self):
        """Return the effective regression coefficients ``phi @ w`` as a column."""
        return (self.phi @ self.w).view(-1, 1)

    def para(self):
        """Return the pair ``(phi, w)``."""
        return self.phi, self.w
#%%

class InvariantCausalPrediction(object):
    """Subset search that keeps features whose residuals look alike across environments.

    For every non-empty feature subset a pooled least-squares fit (no
    intercept) is computed, and for each environment its residuals are
    compared against the residuals of all other environments with
    ``mean_var_test``. A subset is accepted when the smallest p-value,
    multiplied by the number of environments, exceeds ``args["alpha"]``.
    The final model regresses on the intersection of all accepted subsets;
    every other coefficient is zero.

    Args:
        environments: sequence of ``(x, y)`` tensor pairs.
        args: mapping with keys ``"seed"``, ``"alpha"`` and ``"verbose"``.
    """

    def __init__(self, environments, args):
        torch.manual_seed(args["seed"])
        self.coefficients = None
        self.alpha = args["alpha"]

        n_envs = len(environments)

        # Pool all environments, remembering which rows came from where.
        x_all = np.vstack([x.numpy() for x, _ in environments])
        y_all = np.vstack([y.numpy() for _, y in environments])
        e_all = np.hstack([np.full(x.shape[0], e)
                           for e, (x, _) in enumerate(environments)])

        dim = x_all.shape[1]

        accepted_subsets = []
        for subset in self.powerset(range(dim)):
            if not subset:
                continue

            x_s = x_all[:, subset]
            reg = LinearRegression(fit_intercept=False).fit(x_s, y_all)

            p_values = []
            for e in range(n_envs):
                inside = e_all == e
                outside = ~inside

                res_in = (y_all[inside] - reg.predict(x_s[inside, :])).ravel()
                res_out = (y_all[outside] - reg.predict(x_s[outside, :])).ravel()

                p_values.append(self.mean_var_test(res_in, res_out))

            # TODO: Jonas uses "min(p_values) * len(environments) - 1"
            p_value = n_envs * min(p_values)

            if p_value > self.alpha:
                accepted_subsets.append(set(subset))
                if args["verbose"]:
                    print("Accepted subset:", subset)

        if not accepted_subsets:
            # Nothing was accepted: fall back to the all-zero predictor.
            self.coefficients = torch.zeros(dim)
            return

        accepted_features = list(set.intersection(*accepted_subsets))
        if args["verbose"]:
            print("Intersection:", accepted_features)
        self.coefficients = np.zeros(dim)

        if accepted_features:
            x_s = x_all[:, list(accepted_features)]
            reg = LinearRegression(fit_intercept=False).fit(x_s, y_all)
            self.coefficients[list(accepted_features)] = reg.coef_

        self.coefficients = torch.Tensor(self.coefficients)

    def mean_var_test(self, x, y):
        """Return a combined p-value for equal mean and equal variance of ``x`` and ``y``.

        The mean is compared with a Welch t-test and the variances with a
        two-sided F-test; the result is twice the smaller of the two p-values.
        """
        pvalue_mean = ttest_ind(x, y, equal_var=False).pvalue

        variance_ratio = np.var(x, ddof=1) / np.var(y, ddof=1)
        upper_tail = 1 - fdist.cdf(variance_ratio,
                                   x.shape[0] - 1,
                                   y.shape[0] - 1)
        pvalue_var = 2 * min(upper_tail, 1 - upper_tail)

        return 2 * min(pvalue_mean, pvalue_var)

    def powerset(self, s):
        """Iterate over all subsets of ``s`` as tuples, smallest first (empty tuple included)."""
        sizes = range(len(s) + 1)
        return chain.from_iterable(combinations(s, size) for size in sizes)

    def solution(self):
        """Return the coefficient vector as a column tensor."""
        return self.coefficients.view(-1, 1)
    
    
   #%% 

class REX_var(object):
    """Linear regressor trained with a risk-difference (REx-style) penalty.

    Exactly three environments are read by ``train``: the first two are
    used for the training objective and the third is only monitored. The
    last environment is also used by ``__init__`` to pick the best
    regularisation weight.

    Per-environment "risks" are ``MSE - var(y)`` (and, for the two training
    environments, additionally ``- var(y_hat)``). The penalty is the
    absolute ("MAE") or squared ("MSE") difference between the two clamped
    training risks.

    Args:
        environments: sequence of at least three ``(x, y)`` tensor pairs.
        args: mapping with keys ``"seed"``, ``"reg"`` (``-1`` means search
            over ``[0, 1e-4, 1e-1]``), ``"penalty"``,
            ``"use_IRM_parametrization"``, ``"init_noise"``, ``"lr"``,
            ``"n_iterations"``, ``"clip"``, ``"clip_until"``,
            ``"print_clipped"``, ``"L2_reg"``, ``"verbose"`` and (when
            verbose) ``"print_rate"``.
    """

    def __init__(self, environments, args):
        import warnings

        torch.manual_seed(args["seed"])
        # +inf sentinel: any finite validation error is an improvement.
        best_err = float("inf")
        best_phi = None

        self.args = args

        x_val = environments[-1][0]
        y_val = environments[-1][1]

        reg = args["reg"]
        if reg == -1:
            regs = [0, 1e-4, 1e-1]
        else:
            regs = [reg]

        for reg in regs:
            self.train(environments, args, reg=reg)
            err = (x_val @ self.solution() - y_val).pow(2).mean().item()

            # Printed unconditionally (not gated on args["verbose"]).
            print("REXv21 (reg={:.3f}) has {:.3f} validation error.".format(
                reg, err))

            if err < best_err:
                best_err = err
                best_phi = self.phi.clone()

        if best_phi is None:
            # All validation errors were NaN/inf; keep the last trained phi
            # rather than failing on an unbound local variable.
            warnings.warn(
                "REX_var: no finite validation error; keeping the last trained phi.")
            best_phi = self.phi.clone()
        self.phi = best_phi

    def compute_REX_penalty(self, loss1, loss2, clip):
        """Return the penalty on the gap between two risks, each clamped below at ``clip``.

        ``args["penalty"]`` selects ``"MAE"`` (absolute gap) or ``"MSE"``
        (squared gap); any other value raises ``AssertionError``.
        """
        kind = self.args["penalty"]
        if kind not in ("MAE", "MSE"):
            # Explicit raise (same exception type as the former bare assert)
            # so the check is not stripped when Python runs with -O.
            raise AssertionError(
                "unsupported penalty {!r}; expected 'MAE' or 'MSE'".format(kind))

        gap = torch.clamp(loss1, clip) - torch.clamp(loss2, clip)
        if kind == "MAE":
            # abs() has the same value as sqrt(gap**2) but a finite (zero)
            # gradient at gap == 0, which occurs whenever both risks are
            # clipped; the sqrt form back-propagates NaN there.
            return gap.abs()
        return gap ** 2

    def train(self, environments, args, reg=0.5, use_cuda=False):
        """Fit ``phi`` by minimising ``reg * (risk + L2) + (1 - reg) * penalty``.

        ``environments[0]`` and ``environments[1]`` drive the objective;
        ``environments[2]`` is evaluated for monitoring only. ``use_cuda``
        is accepted for interface compatibility and is not used.

        Learning curves are appended to ``self.rs``, ``self.r1s``,
        ``self.r2s``, ``self.r3s`` and ``self.phis`` once per iteration.
        The stored tensors are not detached from the autograd graph.
        """
        baselines = [torch.var(env[1]) for env in environments]
        self.baselines = baselines

        self.rs = []
        self.r1s = []
        self.r2s = []
        self.r3s = []
        # TODO: the *_clipped curves are declared but never populated.
        self.rs_clipped = []
        self.r1s_clipped = []
        self.r2s_clipped = []
        self.r3s_clipped = []
        self.phis = []

        dim_x = environments[0][0].size(1)
        self.dim_x = dim_x

        self.w = torch.ones(dim_x, 1)
        self.w.requires_grad = False
        if args["use_IRM_parametrization"] == 1:
            # Square phi; effective weights are phi @ ones.
            self.phi = torch.nn.Parameter(torch.eye(dim_x, dim_x) + args["init_noise"] * torch.randn(dim_x, dim_x))
            self.weights = self.phi @ self.w
        else:
            # phi is directly the weight column.
            self.phi = torch.nn.Parameter(torch.ones(dim_x, 1) + args["init_noise"] * torch.randn(dim_x, 1))
            self.weights = self.phi

        opt = torch.optim.Adam([self.phi], lr=args["lr"])
        MSE = torch.nn.MSELoss()

        x_e_1, y_e_1 = environments[0]
        x_e_2, y_e_2 = environments[1]
        # Third environment: monitored only, never part of the loss.
        x_e_3, y_e_3 = environments[2]

        for iteration in range(args["n_iterations"]):

            # Clipping is only active for the first clip_until iterations.
            clip = args["clip"]
            if iteration > args["clip_until"]:
                clip = -np.inf

            if args["use_IRM_parametrization"] == 1:
                weights = self.phi @ torch.ones(dim_x, 1)
            else:
                weights = self.phi

            yhat1 = x_e_1 @ weights
            yhat2 = x_e_2 @ weights
            r1 = MSE(yhat1, y_e_1) - baselines[0] - torch.var(yhat1)
            r2 = MSE(yhat2, y_e_2) - baselines[1] - torch.var(yhat2)
            r3 = MSE(x_e_3 @ weights, y_e_3) - baselines[2]

            if args["print_clipped"]:
                # Note: the clamped values are also the ones used in the
                # loss below, not just in the recorded curves.
                r1 = torch.clamp(r1, clip)
                r2 = torch.clamp(r2, clip)
                r3 = torch.clamp(r3, clip)

            # learning curves
            self.rs.append(r1 + r2)
            self.r1s.append(r1)
            self.r2s.append(r2)
            self.r3s.append(r3)
            self.phis.append(weights)

            # COMPUTE LOSS
            penalty = self.compute_REX_penalty(r1, r2, clip)
            error = r1 + r2  # N.B.: NOT USING VALIDATION ENVIRONMENT
            loss = error + args["L2_reg"] * self.phi.pow(2).sum().sqrt()
            loss = reg * loss + (1 - reg) * penalty

            opt.zero_grad()
            loss.backward()
            opt.step()

            if args["verbose"] and iteration % args["print_rate"] == 0:
                w_str = pretty(self.solution())
                w_norm = self.phi.pow(2).sum().sqrt()
                print("it: {:05d} | reg: {:.5f} | loss: {:.5f} | r1: {:.5f} | r2: {:.5f} | w_norm: {:.5f} | {}".format(iteration,
                                                                        reg,
                                                                        loss,
                                                                        r1,
                                                                        r2,
                                                                        w_norm,
                                                                        w_str))

    def solution(self):
        """Return the effective weight column (``phi @ w`` or ``phi``)."""
        if self.args["use_IRM_parametrization"] == 1:
            return self.phi @ self.w
        else:
            return self.phi

    def para(self):
        """Return ``(phi, w)`` under the IRM parametrisation, otherwise ``phi`` alone."""
        if self.args["use_IRM_parametrization"] == 1:
            return self.phi, self.w
        else:
            return self.phi



#%%

class EmpiricalRiskMinimizer(object):
    """Ordinary least squares (no intercept) on the data pooled from all environments.

    ``w`` holds the fitted coefficients as a column tensor and ``phi`` is an
    identity matrix sized to the number of features. ``args`` is accepted
    but not used.
    """

    def __init__(self, environments, args):
        inputs, targets = [], []
        for x, y in environments:
            inputs.append(x)
            targets.append(y)

        x_all = torch.cat(inputs).numpy()
        y_all = torch.cat(targets).numpy()
        n_features = x_all.shape[1]

        model = LinearRegression(fit_intercept=False)
        model.fit(x_all, y_all)

        self.w = torch.Tensor(model.coef_).view(-1, 1)
        self.phi = torch.eye(n_features)

    def solution(self):
        """Return the fitted coefficient column ``w``."""
        return self.w

    def para(self):
        """Return the pair ``(phi, w)``."""
        return self.phi, self.w

#%%

class LISA(object):
    """Linear regressor trained on mixup-interpolated samples (LISA-style).

    Pairs of samples are interpolated with a Beta-distributed weight:

    * within the same environment when their targets are more than 1 apart;
    * across different environments when their targets are at most 1 apart.

    The model ``y_hat = x @ phi @ w`` (``phi`` learned, ``w`` fixed to ones)
    is then fitted on the interpolated samples with per-sample Adam steps.

    Args:
        environments: sequence of ``(x, y)`` tensor pairs; ``x`` is indexed
            as 2-D (``x.size(1)`` is the feature dimension).
        args: dict with keys ``"lr"``, ``"n_iterations"``, ``"verbose"``;
            optional ``"alpha"`` / ``"beta"`` (Beta parameters, default 0.5)
            and optional ``"seed"`` (seeds torch for reproducibility).
    """

    def __init__(self, environments, args):
        import warnings

        if "seed" in args:
            torch.manual_seed(args["seed"])

        # Size phi from the data and register it as a trainable parameter;
        # a plain tensor without requires_grad is silently skipped by Adam.
        dim_x = environments[0][0].size(1)
        self.phi = torch.nn.Parameter(torch.eye(dim_x))
        self.args = args
        self.environments = environments
        self.mixed_environments = self.create_mixed_environments(environments)
        if not self.mixed_environments:
            warnings.warn(
                "LISA: no sample pair met the mixing criteria; "
                "the model is left at its initial value.")
        self.train(self.mixed_environments, args)

    def create_mixed_environments(self, environments):
        """Build the list of interpolated ``(x_mixed, y_mixed)`` samples.

        For every ordered pair of environments ``(i, j)``, sample ``k`` of
        ``i`` is paired with a randomly permuted partner from ``j`` (so a
        same-environment pair is usually two different samples). Each
        sample of ``i`` is used at most once per environment pair.
        """
        alpha = self.args.get('alpha', 0.5)
        beta = self.args.get('beta', 0.5)
        mix_weight = torch.distributions.Beta(alpha, beta)

        all_data = [(x, y, k) for k, (x, y) in enumerate(environments)]
        # Randomise environment order with torch's RNG so that a torch seed
        # makes the whole procedure reproducible.
        order = torch.randperm(len(all_data)).tolist()
        all_data = [all_data[k] for k in order]

        mixed_data = []
        for x_i, y_i, e_i in all_data:
            for x_j, y_j, e_j in all_data:
                n_pairs = min(x_i.shape[0], x_j.shape[0])
                partners = torch.randperm(x_j.shape[0])[:n_pairs].tolist()
                same_env = e_i == e_j

                for idx_i, idx_j in zip(range(n_pairs), partners):
                    # Euclidean distance between the two targets.
                    distance = (y_i[idx_i] - y_j[idx_j]).pow(2).sum().sqrt().item()

                    if same_env:
                        should_mix = distance > 1    # same environment
                    else:
                        should_mix = distance <= 1   # across environments
                    if not should_mix:
                        continue

                    lambda_mix = mix_weight.sample().item()
                    x_mixed = lambda_mix * x_i[idx_i] + (1 - lambda_mix) * x_j[idx_j]
                    y_mixed = lambda_mix * y_i[idx_i] + (1 - lambda_mix) * y_j[idx_j]
                    mixed_data.append((x_mixed, y_mixed))

        return mixed_data

    def train(self, environments, args):
        """Fit ``phi`` with one Adam step per ``(x, y)`` entry of ``environments``.

        ``environments`` is the list produced by
        ``create_mixed_environments``; an empty list leaves ``phi`` untouched.
        """
        dim_x = self.phi.size(0)
        self.w = torch.ones(dim_x, 1, requires_grad=True)
        opt = optim.Adam([self.phi], lr=args["lr"])
        loss = nn.MSELoss()

        for iteration in range(args["n_iterations"]):
            total_loss = 0
            for x_e, y_e in environments:
                pred_y = x_e @ self.phi @ self.w
                error = loss(pred_y, y_e)
                total_loss += error.item()

                opt.zero_grad()
                error.backward()
                opt.step()

            if args["verbose"] and iteration % 100 == 0:
                print(f"{iteration:05d} | Total Loss: {total_loss:.5f}")

    def solution(self):
        """Return the effective regression coefficients ``phi @ w`` as a column."""
        return (self.phi @ self.w).view(-1, 1)

    def para(self):
        """Return the pair ``(phi, w)``."""
        return self.phi, self.w
